library(tidyverse)
## ── Attaching packages ─────────────────────────────────────── tidyverse 1.3.1 ──
## ✓ ggplot2 3.3.5 ✓ purrr 0.3.4
## ✓ tibble 3.1.6 ✓ dplyr 1.0.7
## ✓ tidyr 1.1.4 ✓ stringr 1.4.0
## ✓ readr 2.1.1 ✓ forcats 0.5.1
## ── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
## x dplyr::filter() masks stats::filter()
## x dplyr::lag() masks stats::lag()
library(ggplot2)
# fix_ntopic ----
# Load the fix_ntopic results and attach per-row summary statistics.
res_fix_ntopic <- read.csv("../data/res_fix_ntopic.csv")

# 4 x nrow matrix: for every input row, the mean and sd of row positions
# 16-20 and of row positions 21-25 (presumably the repeated supervised /
# unsupervised runs -- TODO confirm against the CSV layout).
# NOTE(review): apply() converts the data frame to a matrix first, so this
# relies on the columns being numeric.
res_sum <- apply(res_fix_ntopic, 1, function(row) {
  first_runs <- row[16:20]
  second_runs <- row[21:25]
  c(mean(first_runs), mean(second_runs), sd(first_runs), sd(second_runs))
})

# Keep only the columns used for plotting, then append the "(c1, c2)" prior
# label (built from the first two columns of the raw table) and the summary
# statistics. The right-hand `res_fix_ntopic[1:2]` still refers to the raw
# table because the reassignment happens after the pipeline is evaluated.
res_fix_ntopic <- res_fix_ntopic %>%
  select(d, k0, k1, nlabel, a, supervised, unsupervised) %>%
  mutate(
    prior = apply(res_fix_ntopic[1:2], 1, function(x) {
      paste0("(", as.character(x[1]), ", ", as.character(x[2]), ")")
    }),
    supervised_mean = res_sum[1, ],
    unsupervised_mean = res_sum[2, ],
    supervised_sd = res_sum[3, ],
    unsupervised_sd = res_sum[4, ]
  )
# Figure 1: accuracy against nlabel, one panel per method (rows) and d
# (columns), points coloured by log(a).
# NOTE(review): per the ggplot2 docs, space = "free_x" has no effect unless
# the x scales are also free -- confirm whether scales = "free_x" was meant.
res_fix_ntopic %>%
  pivot_longer(
    cols = c(supervised, unsupervised),
    names_to = "method",
    values_to = "accuracy"
  ) %>%
  ggplot(mapping = aes(x = nlabel, y = accuracy, colour = log(a))) +
  geom_point(size = 1) +
  labs(x = "number of classes") +
  facet_grid(rows = vars(method), cols = vars(d), space = "free_x")

# No plot argument: ggsave() writes the most recent plot.
ggsave(
  filename = "figure1_fix_ntopic_dimension.png",
  width = 6,
  height = 4,
  dpi = 150
)
# Figure 2: accuracy against nlabel, coloured by method, one panel per
# prior label (rows) and d (columns).
# The x axis shows nlabel, so it is labelled "number of classes" to match
# figure 1; the previous label ("number of unshared topics") is the label
# this script uses for k0.
# NOTE(review): per the ggplot2 docs, space = "free_x" has no effect unless
# the x scales are also free -- confirm whether scales = "free_x" was meant.
res_fix_ntopic %>%
  pivot_longer(
    cols = c(supervised, unsupervised),
    names_to = "method",
    values_to = "accuracy"
  ) %>%
  ggplot(aes(x = nlabel, y = accuracy, color = method)) +
  geom_point() +
  labs(x = "number of classes") +
  facet_grid(prior ~ d, space = "free_x")

# No plot argument: ggsave() writes the most recent plot.
ggsave("figure2_fix_ntopic_prior.png", dpi = 150, width = 6, height = 4)
# Comparison only (not saved): same layout as figure 1, but plotting the
# per-row summary columns instead of the supervised / unsupervised columns.

# Means: supervised_mean vs unsupervised_mean.
res_fix_ntopic %>%
  pivot_longer(
    cols = c(supervised_mean, unsupervised_mean),
    names_to = "method",
    values_to = "accuracy"
  ) %>%
  ggplot(mapping = aes(x = nlabel, y = accuracy, colour = log(a))) +
  geom_point(size = 1) +
  facet_grid(rows = vars(method), cols = vars(d), space = "free_x")

# Standard deviations: supervised_sd vs unsupervised_sd. The value column
# is still called "accuracy", so the y axis carries that name.
res_fix_ntopic %>%
  pivot_longer(
    cols = c(supervised_sd, unsupervised_sd),
    names_to = "method",
    values_to = "accuracy"
  ) %>%
  ggplot(mapping = aes(x = nlabel, y = accuracy, colour = log(a))) +
  geom_point(size = 1) +
  facet_grid(rows = vars(method), cols = vars(d), space = "free_x")

# Comparison only (not saved): same layout as figure 2 (prior x d panels,
# coloured by method), but plotting the per-row summary columns.
# The x axis shows nlabel, so it is labelled "number of classes" to match
# figure 1; the previous label ("number of unshared topics") is the label
# this script uses for k0.

# Means: supervised_mean vs unsupervised_mean.
res_fix_ntopic %>%
  pivot_longer(
    cols = c(supervised_mean, unsupervised_mean),
    names_to = "method",
    values_to = "accuracy"
  ) %>%
  ggplot(aes(x = nlabel, y = accuracy, color = method)) +
  geom_point() +
  labs(x = "number of classes") +
  facet_grid(prior ~ d, space = "free_x")

# Standard deviations: supervised_sd vs unsupervised_sd. The value column
# is still called "accuracy", so the y axis carries that name.
res_fix_ntopic %>%
  pivot_longer(
    cols = c(supervised_sd, unsupervised_sd),
    names_to = "method",
    values_to = "accuracy"
  ) %>%
  ggplot(aes(x = nlabel, y = accuracy, color = method)) +
  geom_point() +
  labs(x = "number of classes") +
  facet_grid(prior ~ d, space = "free_x")

# d100 ----
# Build the d = 100 table: the dedicated d100 results followed by the
# d == 100 rows of the fix_ntopic results.
tem <- filter(read.csv("../data/res_fix_ntopic.csv"), d == 100)
res_d100 <- rbind(read.csv("../data/res_d100.csv"), tem)

# 4 x nrow matrix: mean and sd of row positions 16-20 and 21-25
# (presumably the repeated supervised / unsupervised runs -- TODO confirm).
# NOTE(review): apply() converts the data frame to a matrix first, so this
# relies on the columns being numeric.
res_sum <- apply(res_d100, 1, function(row) {
  first_runs <- row[16:20]
  second_runs <- row[21:25]
  c(mean(first_runs), mean(second_runs), sd(first_runs), sd(second_runs))
})

# Keep the plotting columns, then append the "(c1, c2)" prior label (from
# the first two columns of the combined raw table) and the summaries.
res_d100 <- res_d100 %>%
  select(k0, k1, nlabel, a, supervised, unsupervised) %>%
  mutate(
    prior = apply(res_d100[1:2], 1, function(x) {
      paste0("(", as.character(x[1]), ", ", as.character(x[2]), ")")
    }),
    supervised_mean = res_sum[1, ],
    unsupervised_mean = res_sum[2, ],
    supervised_sd = res_sum[3, ],
    unsupervised_sd = res_sum[4, ]
  )

# After the select() above, columns 1 and 2 are k0 and k1; give them
# descriptive names for the facet labels.
names(res_d100)[1:2] <- c("class_specific_topics", "shared_topics")
# Figure 3: accuracy against nlabel for the d = 100 table, coloured by
# method, with one panel per (shared_topics, class_specific_topics) pair.
# label_both prints "variable: value" in each facet strip.
res_d100 %>%
  pivot_longer(
    cols = c(supervised, unsupervised),
    names_to = "method",
    values_to = "accuracy"
  ) %>%
  ggplot(mapping = aes(x = nlabel, y = accuracy, colour = method)) +
  geom_point(size = 1) +
  labs(x = "number of classes") +
  facet_grid(
    rows = vars(shared_topics),
    cols = vars(class_specific_topics),
    labeller = label_both
  )

# No plot argument: ggsave() writes the most recent plot.
ggsave(
  filename = "figure3_d100_k0_k1.png",
  width = 6,
  height = 4,
  dpi = 150
)
# random ----
# Load the random-configuration results and attach per-row summaries plus
# the variance implied by the (alpha_p, beta_p) prior parameters.
res_random <- read.csv("../data/res_random.csv")

# 4 x nrow matrix: mean and sd of row positions 16-20 and 21-25
# (presumably the repeated supervised / unsupervised runs -- TODO confirm).
# NOTE(review): apply() converts the data frame to a matrix first, so this
# relies on the columns being numeric.
res_sum <- apply(res_random, 1, function(x) {
  c(mean(x[16:20]), mean(x[21:25]), sd(x[16:20]), sd(x[21:25]))
})

res_random <- res_random %>%
  # "(c1, c2)" label built from the first two columns of the raw table.
  mutate(prior = apply(res_random[1:2], 1, function(x) {
    paste0("(", as.character(x[1]), ", ", as.character(x[2]), ")")
  })) %>%
  select(
    alpha_p, beta_p, d, k0, k1, nlabel, a,
    supervised, unsupervised, prior
  ) %>%
  mutate(
    supervised_mean = res_sum[1, ],
    unsupervised_mean = res_sum[2, ],
    supervised_sd = res_sum[3, ],
    unsupervised_sd = res_sum[4, ]
  ) %>%
  # Variance of a Beta(alpha, beta) distribution:
  #   alpha * beta / ((alpha + beta)^2 * (alpha + beta + 1)).
  # The earlier expression omitted the "+ 1" in the last factor.
  # Assumes alpha_p / beta_p parameterise a Beta prior -- TODO confirm.
  mutate(
    var_p = alpha_p * beta_p /
      ((alpha_p + beta_p)^2 * (alpha_p + beta_p + 1))
  )
# Figure 4: accuracy against log(var_p) for the random configurations,
# coloured by method.
res_random %>%
  pivot_longer(
    cols = c(supervised, unsupervised),
    names_to = "method",
    values_to = "accuracy"
  ) %>%
  ggplot(mapping = aes(x = log(var_p), y = accuracy, colour = method)) +
  geom_point(size = 1)

# No plot argument: ggsave() writes the most recent plot.
ggsave(
  filename = "figure4_random_var_p.png",
  width = 6,
  height = 4,
  dpi = 150
)
# Long-format copy of the random results: one row per (configuration,
# method) with the supervised / unsupervised value in `accuracy`.
res <- pivot_longer(
  res_random,
  cols = c(supervised, unsupervised),
  names_to = "method",
  values_to = "accuracy"
)

# Exploratory scatter plots (not saved): accuracy against one quantity at
# a time, coloured by method.

# k0
ggplot(data = res, mapping = aes(x = k0, y = accuracy, colour = method)) +
  geom_point() +
  labs(x = "number of unshared topics")

# k1
ggplot(data = res, mapping = aes(x = k1, y = accuracy, colour = method)) +
  geom_point() +
  labs(x = "number of shared topics")

# nlabel * k0 - k1
ggplot(
  data = res,
  mapping = aes(x = nlabel * k0 - k1, y = accuracy, colour = method)
) +
  geom_point() +
  labs(x = "k0*nlabel-k1")

# d
ggplot(data = res, mapping = aes(x = d, y = accuracy, colour = method)) +
  geom_point() +
  labs(x = "dimension of data")

# nlabel
ggplot(data = res, mapping = aes(x = nlabel, y = accuracy, colour = method)) +
  geom_point() +
  labs(x = "number of labels")
